
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F


class CQDTransEA(nn.Module):
    """TransE + attribute model.

    Relations are scored with TransE, i.e. the p-norm distance
    ``||h + r - t||_p``. Attribute values are predicted with a per-attribute
    linear model on the entity embedding, ``<e, a_attr> + b_attr``, optionally
    passed through a sigmoid.

    The last entity id (``nentity - 1``) is treated as a dummy/virtual node by
    ``predict_all_value`` and ``score_attribution_exists``.

    Args:
        nentity: Number of entities (including the dummy node).
        rank: Embedding dimension shared by entities, relations and attributes.
        nattr: Number of attributes.
        nrelation: Number of relations.
        p_norm: Order of the norm used in the TransE distance.
        use_attributes: Stored as ``self.use_attributes``; not read by the
            methods of this class.
        do_sigmoid: If True, attribute predictions are passed through a sigmoid.
    """

    def __init__(self,
                 nentity,
                 rank,
                 nattr,
                 nrelation,
                 p_norm=1,
                 use_attributes=True,
                 do_sigmoid=True,

                 *args,
                 **kwargs
                 ):
        super().__init__(*args, **kwargs)

        self.p_norm = p_norm
        self.use_attributes = use_attributes
        self.do_sigmoid = do_sigmoid
        self.nentity = nentity
        self.nrelation = nrelation
        self.rank = rank
        self.nattr = nattr

        self.ent_embeddings = nn.Embedding(self.nentity, self.rank)
        self.rel_embeddings = nn.Embedding(self.nrelation, self.rank)

        nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
        nn.init.xavier_uniform_(self.rel_embeddings.weight.data)

        self.attr_embeddings = nn.Embedding(self.nattr, self.rank)
        nn.init.xavier_uniform_(self.attr_embeddings.weight.data)
        # Per-attribute scalar bias of the attribute regressor.
        self.b = nn.Embedding(self.nattr, 1)

    def _calc(self, h, r, t):
        """Return the flattened TransE distance ``||h + r - t||_p`` over the last dim."""
        return torch.norm(h + r - t, self.p_norm, -1).flatten()

    def score_rel(self, data):
        """Return TransE distances for the triples in ``data``.

        ``data`` must provide the index tensors ``batch_h``, ``batch_r`` and
        ``batch_t``. Lower values mean more plausible triples.
        """
        h = self.ent_embeddings(data['batch_h'])
        r = self.rel_embeddings(data['batch_r'])
        t = self.ent_embeddings(data['batch_t'])
        return self._calc(h, r, t)

    def score_attr(self, data):
        """Return predicted and target attribute values stacked on a new last dim.

        ``data`` is expected to hold ``batch_e`` (entity ids), ``batch_a``
        (attribute ids) and ``batch_v`` (target values). The result stacks
        ``(prediction, target)`` along ``dim=-1``. When ``batch_e`` is missing,
        an empty float32 tensor is returned.
        """
        if "batch_e" not in data:
            # Put the empty result next to the batch when one is given;
            # otherwise use the model's own device instead of raising KeyError.
            if 'batch_h' in data:
                device = data['batch_h'].device
            else:
                device = self.ent_embeddings.weight.device
            return torch.empty(0, dtype=torch.float32, device=device)
        e = self.ent_embeddings(data['batch_e'])
        v = data['batch_v']
        preds = self.predict_attribute_values(e, data['batch_a'])
        return torch.stack((preds, v), dim=-1)

    def score(self, data_entity, data_att):
        """Return ``(-relation distances, stacked attribute predictions/targets)``."""
        return -self.score_rel(data_entity), self.score_attr(data_att)

    def predict_attribute_values(self, e_emb, attributes):
        """Predict attribute values for a batch of entity embeddings.

        Args:
            e_emb: Entity embeddings, ``[..., rank]`` (e.g. ``[B, rank]``).
            attributes: Attribute ids whose shape matches ``e_emb`` without the
                last dim (e.g. ``[B]``), or a 0-d id that broadcasts.

        Returns:
            ``<e_emb, a_attr> + b_attr``, with the sigmoid applied when
            ``self.do_sigmoid`` is set. The shape is that of the broadcast dot
            product (e.g. ``[B]``).
        """
        a = self.attr_embeddings(attributes)
        # Drop only the trailing singleton of the [..., 1] bias. A bare
        # squeeze() also removes batch dims of size 1, so ids of shape [B, 1]
        # would broadcast the result to [B, B] instead of [B, 1].
        b = self.b(attributes).squeeze(-1)
        predictions = torch.sum(e_emb * a, dim=-1) + b
        if self.do_sigmoid:
            predictions = torch.sigmoid(predictions)
        return predictions

    def predict_for_test(self, entity, attribution):
        """Predict attribute ``attribution`` for entity ids ``entity``."""
        return self.predict_attribute_values(self.ent_embeddings(entity), attribution)

    def predict_all_value(self, attribution):
        """Predict ``attribution`` for every real entity (ids ``0 .. nentity - 2``).

        The dummy node (id ``nentity - 1``) is excluded. Presumably
        ``attribution`` is a single (0-d) attribute id; other shapes must
        broadcast against ``[nentity - 1]``. TODO confirm against callers.
        """
        entity_ids = torch.arange(self.nentity - 1, device=attribution.device)
        return self.predict_for_test(entity_ids, attribution)

    def compute_stdev(self, attribution):
        """Return the population standard deviation (last dim) of ``predict_all_value``."""
        predictions = self.predict_all_value(attribution)
        return torch.std(predictions, dim=-1, unbiased=False)

    def score_attribution_exists(self, attribution):
        """Score how plausibly each real entity has attribute ``attribution``.

        Attribute id ``a`` is mapped to relation id ``a + nrelation - nattr``,
        i.e. attribute relations occupy the last ``nattr`` relation ids. Every
        entity is then scored as the tail of (dummy node, that relation, entity)
        with TransE. The dummy node's own score is dropped, and the remaining
        scores are min-max normalised to [0, 1]. Higher values mean closer to
        ``dummy + relation``.

        NOTE(review): the original comment said the attribute "is even";
        presumably a constraint on the ids callers pass. Confirm.

        Returns:
            For a single (0-d) attribute id, a tensor of shape
            ``[nentity - 1]``. If all scores are equal, zeros are returned
            instead of the NaNs a plain min-max division would produce.
        """
        device = attribution.device
        dummy_entity = self.ent_embeddings(torch.tensor(self.nentity - 1, device=device))
        attribution_relation = self.rel_embeddings(attribution + self.nrelation - self.nattr)
        all_scores = self.score_all_entities(
            dummy_entity.expand(1, 1, -1),
            attribution_relation.expand(1, 1, -1),
            self.ent_embeddings.weight,
        ).squeeze()
        # Exclude the dummy node (last entity) from the result.
        exists_score = all_scores[:-1]
        lowest = exists_score.min()
        span = exists_score.max() - lowest
        # Avoid 0/0 when every score is identical: the numerator is all zeros
        # in that case, so dividing by 1 yields zeros.
        span = torch.where(span > 0, span, torch.ones_like(span))
        return (exists_score - lowest) / span

    def score_all_entities(self,
                           lhs_emb: Tensor,
                           rel_emb: Tensor,
                           rhs_emb: Tensor) -> Tensor:
        """Score every candidate tail for each (head, relation) pair with TransE.

        Args:
            lhs_emb: Head embeddings of shape ``[B, V, rank]``.
            rel_emb: Relation embeddings broadcastable to ``lhs_emb``.
            rhs_emb: Candidate tail embeddings of shape ``[N, rank]``.

        Returns:
            ``[B, V, N]`` tensor of negated distances ``-||lhs + rel - rhs||_p``.
        """
        predicted = lhs_emb + rel_emb
        n_batch, n_var = lhs_emb.shape[0], lhs_emb.shape[1]

        # Allocate directly on the target device (no host-to-device copy).
        scores = torch.empty(size=(n_batch, n_var, rhs_emb.shape[0]), device=lhs_emb.device)
        # Fill one (batch, var) row at a time so that only a single [N, rank]
        # difference tensor is alive at once, instead of a full
        # [B, V, N, rank] broadcast.
        for ent in range(n_batch):
            for var in range(n_var):
                scores[ent, var] = -torch.norm(predicted[ent, var] - rhs_emb, self.p_norm, -1)

        return scores